{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Requirement already satisfied: pip in ./.conda/lib/python3.11/site-packages (24.2)\n",
      "Collecting pip\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl (1.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: pip\n",
      "  Attempting uninstall: pip\n",
      "    Found existing installation: pip 24.2\n",
      "    Uninstalling pip-24.2:\n",
      "      Successfully uninstalled pip-24.2\n",
      "Successfully installed pip-24.3.1\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting gym\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ab/b1/eb05a423eb801ab7d0715d6a3b28d92589e30b437052553df19ca2087240/gym-0.26.2.tar.gz (721 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m721.7/721.7 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Installing build dependencies ... \u001b[?25ldone\n",
      "\u001b[?25h  Getting requirements to build wheel ... \u001b[?25ldone\n",
      "\u001b[?25h  Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25hCollecting numpy>=1.18.0 (from gym)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/7a/f0/80811e836484262b236c684a75dfc4ba0424bc670e765afaa911468d9f39/numpy-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.3/16.3 MB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting cloudpickle>=1.2.0 (from gym)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/48/41/e1d85ca3cab0b674e277c8c4f678cf66a91cd2cecf93df94353a606fe0db/cloudpickle-3.1.0-py3-none-any.whl (22 kB)\n",
      "Collecting gym_notices>=0.0.4 (from gym)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/25/26/d786c6bec30fe6110fd3d22c9a273a2a0e56c0b73b93e25ea1af5a53243b/gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\n",
      "Building wheels for collected packages: gym\n",
      "  Building wheel for gym (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827625 sha256=e7026b5fc03e5c81cfddaf9db06b42270c063389fdda74e80f24ae5789f4861e\n",
      "  Stored in directory: /home/ubuntu/.cache/pip/wheels/9e/54/7b/2af703e473b988b578e6289bd91c13a2b70b240c08d0c06839\n",
      "Successfully built gym\n",
      "Installing collected packages: gym_notices, numpy, cloudpickle, gym\n",
      "Successfully installed cloudpickle-3.1.0 gym-0.26.2 gym_notices-0.0.8 numpy-2.1.3\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting tqdm\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl (78 kB)\n",
      "Installing collected packages: tqdm\n",
      "Successfully installed tqdm-4.67.1\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting nes-py\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/04/6b/51c9efe4fc67a9311b1125a19ec7a6176c8f0b8334e418d7f5bc53aeec56/nes_py-8.2.1.tar.gz (77 kB)\n",
      "  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: gym>=0.17.2 in ./.conda/lib/python3.11/site-packages (from nes-py) (0.26.2)\n",
      "Requirement already satisfied: numpy>=1.18.5 in ./.conda/lib/python3.11/site-packages (from nes-py) (2.1.3)\n",
      "Collecting pyglet<=1.5.21,>=1.4.0 (from nes-py)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/0b/b7/7736d7638d91b354700dc9bae447728c514c4bc6ecb4c0f7e0cd9a390f20/pyglet-1.5.21-py3-none-any.whl (1.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: tqdm>=4.48.2 in ./.conda/lib/python3.11/site-packages (from nes-py) (4.67.1)\n",
      "Requirement already satisfied: cloudpickle>=1.2.0 in ./.conda/lib/python3.11/site-packages (from gym>=0.17.2->nes-py) (3.1.0)\n",
      "Requirement already satisfied: gym_notices>=0.0.4 in ./.conda/lib/python3.11/site-packages (from gym>=0.17.2->nes-py) (0.0.8)\n",
      "Building wheels for collected packages: nes-py\n",
      "  Building wheel for nes-py (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for nes-py: filename=nes_py-8.2.1-cp311-cp311-linux_x86_64.whl size=49477 sha256=3a046bf6f908a6aaee324102e669b250510f2d37c6ce922424194bd7b46792eb\n",
      "  Stored in directory: /home/ubuntu/.cache/pip/wheels/c3/3e/05/305012edfda4d5cdbf214ef8a0173d60cbf4efdc8f9cb1d3ce\n",
      "Successfully built nes-py\n",
      "Installing collected packages: pyglet, nes-py\n",
      "Successfully installed nes-py-8.2.1 pyglet-1.5.21\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting gym-super-mario-bros\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/0c/4f/0951ba1480c6f874a5e3b043f20d06fc1951cc03844e8bcfa6208291d712/gym_super_mario_bros-7.4.0-py3-none-any.whl (199 kB)\n",
      "Requirement already satisfied: nes-py>=8.1.4 in ./.conda/lib/python3.11/site-packages (from gym-super-mario-bros) (8.2.1)\n",
      "Requirement already satisfied: gym>=0.17.2 in ./.conda/lib/python3.11/site-packages (from nes-py>=8.1.4->gym-super-mario-bros) (0.26.2)\n",
      "Requirement already satisfied: numpy>=1.18.5 in ./.conda/lib/python3.11/site-packages (from nes-py>=8.1.4->gym-super-mario-bros) (2.1.3)\n",
      "Requirement already satisfied: pyglet<=1.5.21,>=1.4.0 in ./.conda/lib/python3.11/site-packages (from nes-py>=8.1.4->gym-super-mario-bros) (1.5.21)\n",
      "Requirement already satisfied: tqdm>=4.48.2 in ./.conda/lib/python3.11/site-packages (from nes-py>=8.1.4->gym-super-mario-bros) (4.67.1)\n",
      "Requirement already satisfied: cloudpickle>=1.2.0 in ./.conda/lib/python3.11/site-packages (from gym>=0.17.2->nes-py>=8.1.4->gym-super-mario-bros) (3.1.0)\n",
      "Requirement already satisfied: gym_notices>=0.0.4 in ./.conda/lib/python3.11/site-packages (from gym>=0.17.2->nes-py>=8.1.4->gym-super-mario-bros) (0.0.8)\n",
      "Installing collected packages: gym-super-mario-bros\n",
      "Successfully installed gym-super-mario-bros-7.4.0\n",
      "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
      "Collecting torch\n",
      "  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-linux_x86_64.whl (780.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m780.5/780.5 MB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0m\n",
      "\u001b[?25hCollecting torchvision\n",
      "  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.3/7.3 MB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting torchaudio\n",
      "  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m42.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting filelock (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl (11 kB)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in ./.conda/lib/python3.11/site-packages (from torch) (4.12.2)\n",
      "Collecting networkx (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/networkx-3.2.1-py3-none-any.whl (1.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m99.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting jinja2 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/Jinja2-3.1.3-py3-none-any.whl (133 kB)\n",
      "Collecting fsspec (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/fsspec-2024.2.0-py3-none-any.whl (170 kB)\n",
      "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m36.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m11.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-nccl-cu12==2.21.5 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.7/188.7 MB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
      "Collecting triton==3.1.0 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 MB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting sympy==1.13.1 (from torch)\n",
      "  Downloading https://download.pytorch.org/whl/sympy-1.13.1-py3-none-any.whl (6.2 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m112.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
      "  Downloading https://download.pytorch.org/whl/cu121/nvidia_nvjitlink_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (19.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.8/19.8 MB\u001b[0m \u001b[31m22.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)\n",
      "  Downloading https://download.pytorch.org/whl/mpmath-1.3.0-py3-none-any.whl (536 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.2/536.2 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy in ./.conda/lib/python3.11/site-packages (from torchvision) (2.1.3)\n",
      "Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)\n",
      "  Downloading https://download.pytorch.org/whl/pillow-10.2.0-cp311-cp311-manylinux_2_28_x86_64.whl (4.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m134.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting MarkupSafe>=2.0 (from jinja2->torch)\n",
      "  Downloading https://download.pytorch.org/whl/MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (28 kB)\n",
      "Installing collected packages: mpmath, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio\n",
      "Successfully installed MarkupSafe-2.1.5 filelock-3.13.1 fsspec-2024.2.0 jinja2-3.1.3 mpmath-1.3.0 networkx-3.2.1 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.1.105 nvidia-nvtx-cu12-12.1.105 pillow-10.2.0 sympy-1.13.1 torch-2.5.1+cu121 torchaudio-2.5.1+cu121 torchvision-0.20.1+cu121 triton-3.1.0\n"
     ]
    }
   ],
   "source": [
    "# Environment setup. Run this cell once on a fresh kernel before anything else.\n",
    "#\n",
    "# %pip (rather than a bare `!pip` / `!pip3`) installs into the environment of the\n",
    "# running kernel; a shell `pip` may belong to a different interpreter on PATH.\n",
    "# -q keeps the installer from flooding the notebook with download logs.\n",
    "%pip install -q -U pip\n",
    "\n",
    "# RL / emulator stack, pinned to the versions resolved in the last recorded run so\n",
    "# that Restart & Run All rebuilds the same environment.\n",
    "# NumPy is deliberately held on the 1.x line: the unpinned install resolved\n",
    "# NumPy 2.x, and NumPy 2.0 removed aliases such as np.bool8 that gym 0.26.2 still\n",
    "# references. NOTE(review): relax this pin only after confirming the whole\n",
    "# gym / nes-py stack works on NumPy 2.\n",
    "# Everything is installed in one resolver call so NumPy 2.x is never pulled in first.\n",
    "%pip install -q numpy==1.26.4 gym==0.26.2 tqdm==4.67.1 nes-py==8.2.1 gym-super-mario-bros==7.4.0\n",
    "\n",
    "# CUDA 12.1 builds of PyTorch are served from the dedicated PyTorch wheel index.\n",
    "%pip install -q torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting jupyter_contrib_nbextensions\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/50/91/78cc4362611dbde2b0cd068204aaf1b8899d0459c50d8ff9daca8c069791/jupyter_contrib_nbextensions-0.7.0.tar.gz (23.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.5/23.5 MB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hCollecting ipython_genutils (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl (26 kB)\n",
      "Collecting jupyter_contrib_core>=0.3.3 (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/50/94/0d37e5b49ea1c8bf204c46f9b0257c1f3319a4ab88acbd401da2cab25e55/jupyter_contrib_core-0.4.2.tar.gz (17 kB)\n",
      "  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: jupyter_core in ./.conda/lib/python3.11/site-packages (from jupyter_contrib_nbextensions) (5.7.2)\n",
      "Collecting jupyter_highlight_selected_word>=0.1.1 (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/50/d7/19ab7cfd60bf268d2abbacc52d4295a40f52d74dfc0d938e4761ee5e598b/jupyter_highlight_selected_word-0.2.0-py2.py3-none-any.whl (11 kB)\n",
      "Collecting jupyter_nbextensions_configurator>=0.4.0 (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/05/fe/cffb14a4fbb43cf276aa3047e42c3f9ecfda851ba3c466295401f6b1e085/jupyter_nbextensions_configurator-0.6.4-py2.py3-none-any.whl (466 kB)\n",
      "Collecting nbconvert>=6.0 (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/b8/bb/bb5b6a515d1584aa2fd89965b11db6632e4bdc69495a52374bcc36e56cfa/nbconvert-7.16.4-py3-none-any.whl (257 kB)\n",
      "Collecting notebook>=6.0 (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/46/77/53732fbf48196af9e51c2a61833471021c1d77d335d57b96ee3588c0c53d/notebook-7.2.2-py3-none-any.whl (5.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.0/5.0 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: tornado in ./.conda/lib/python3.11/site-packages (from jupyter_contrib_nbextensions) (6.4.2)\n",
      "Requirement already satisfied: traitlets>=4.1 in ./.conda/lib/python3.11/site-packages (from jupyter_contrib_nbextensions) (5.14.3)\n",
      "Collecting lxml (from jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/42/07/b29571a58a3a80681722ea8ed0ba569211d9bb8531ad49b5cacf6d409185/lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (5.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.0/5.0 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: setuptools in ./.conda/lib/python3.11/site-packages (from jupyter_contrib_core>=0.3.3->jupyter_contrib_nbextensions) (75.1.0)\n",
      "Collecting jupyter-server (from jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/57/e1/085edea6187a127ca8ea053eb01f4e1792d778b4d192c74d32eb6730fed6/jupyter_server-2.14.2-py3-none-any.whl (383 kB)\n",
      "Collecting pyyaml (from jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (762 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m763.0/763.0 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting beautifulsoup4 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl (147 kB)\n",
      "Collecting bleach!=5.0.0 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl (163 kB)\n",
      "Collecting defusedxml (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl (25 kB)\n",
      "Requirement already satisfied: jinja2>=3.0 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.0->jupyter_contrib_nbextensions) (3.1.3)\n",
      "Collecting jupyterlab-pygments (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl (15 kB)\n",
      "Requirement already satisfied: markupsafe>=2.0 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.0->jupyter_contrib_nbextensions) (2.1.5)\n",
      "Collecting mistune<4,>=2.0.3 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f0/74/c95adcdf032956d9ef6c89a9b8a5152bf73915f8c633f3e3d88d06bd699c/mistune-3.0.2-py3-none-any.whl (47 kB)\n",
      "Collecting nbclient>=0.5.0 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/26/1a/ed6d1299b1a00c1af4a033fdee565f533926d819e084caf0d2832f6f87c6/nbclient-0.10.1-py3-none-any.whl (25 kB)\n",
      "Collecting nbformat>=5.7 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl (78 kB)\n",
      "Requirement already satisfied: packaging in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.0->jupyter_contrib_nbextensions) (24.2)\n",
      "Collecting pandocfilters>=1.4.1 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl (8.7 kB)\n",
      "Requirement already satisfied: pygments>=2.4.1 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.0->jupyter_contrib_nbextensions) (2.18.0)\n",
      "Collecting tinycss2 (from nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl (26 kB)\n",
      "Requirement already satisfied: platformdirs>=2.5 in ./.conda/lib/python3.11/site-packages (from jupyter_core->jupyter_contrib_nbextensions) (4.3.6)\n",
      "Collecting jupyterlab-server<3,>=2.27.1 (from notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl (59 kB)\n",
      "Collecting jupyterlab<4.3,>=4.2.0 (from notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f0/04/853abc46fef36afd4e5f9a4fd1fbc1b477f910a29bb71711b6653098b703/jupyterlab-4.2.6-py3-none-any.whl (11.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting notebook-shim<0.3,>=0.2 (from notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f9/33/bd5b9137445ea4b680023eb0469b2bb969d61303dedb2aac6560ff3d14a1/notebook_shim-0.2.4-py3-none-any.whl (13 kB)\n",
      "Collecting webencodings (from bleach!=5.0.0->nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl (11 kB)\n",
      "Collecting anyio>=3.1.0 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e4/f5/f2b75d2fc6f1a260f340f0e7c6a060f4dd2961cc16884ed851b0d18da06a/anyio-4.6.2.post1-py3-none-any.whl (90 kB)\n",
      "Collecting argon2-cffi>=21.1 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/a4/6a/e8a041599e78b6b3752da48000b14c8d1e8a04ded09c88c714ba047f34f5/argon2_cffi-23.1.0-py3-none-any.whl (15 kB)\n",
      "Requirement already satisfied: jupyter-client>=7.4.4 in ./.conda/lib/python3.11/site-packages (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions) (8.6.3)\n",
      "Collecting jupyter-events>=0.9.0 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/a5/94/059180ea70a9a326e1815176b2370da56376da347a796f8c4f0b830208ef/jupyter_events-0.10.0-py3-none-any.whl (18 kB)\n",
      "Collecting jupyter-server-terminals>=0.4.4 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/07/2d/2b32cdbe8d2a602f697a649798554e4f072115438e92249624e532e8aca6/jupyter_server_terminals-0.5.3-py3-none-any.whl (13 kB)\n",
      "Collecting overrides>=5.0 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/2c/ab/fc8290c6a4c722e5514d80f62b2dc4c4df1a68a41d1364e625c35990fcf3/overrides-7.7.0-py3-none-any.whl (17 kB)\n",
      "Collecting prometheus-client>=0.9 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/84/2d/46ed6436849c2c88228c3111865f44311cff784b4aabcdef4ea2545dbc3d/prometheus_client-0.21.0-py3-none-any.whl (54 kB)\n",
      "Requirement already satisfied: pyzmq>=24 in ./.conda/lib/python3.11/site-packages (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions) (26.2.0)\n",
      "Collecting send2trash>=1.8.2 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/40/b0/4562db6223154aa4e22f939003cb92514c79f3d4dccca3444253fd17f902/Send2Trash-1.8.3-py3-none-any.whl (18 kB)\n",
      "Collecting terminado>=0.8.3 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl (14 kB)\n",
      "Collecting websocket-client>=1.7 (from jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl (58 kB)\n",
      "Collecting async-lru>=1.0.0 (from jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/fa/9f/3c3503693386c4b0f245eaf5ca6198e3b28879ca0a40bde6b0e319793453/async_lru-2.0.4-py3-none-any.whl (6.1 kB)\n",
      "Collecting httpx>=0.25.0 (from jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/8f/fb/a19866137577ba60c6d8b69498dc36be479b13ba454f691348ddf428f185/httpx-0.28.0-py3-none-any.whl (73 kB)\n",
      "Requirement already satisfied: ipykernel>=6.5.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (6.29.5)\n",
      "Collecting jupyter-lsp>=2.0.0 (from jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/07/e0/7bd7cff65594fd9936e2f9385701e44574fc7d721331ff676ce440b14100/jupyter_lsp-2.2.5-py3-none-any.whl (69 kB)\n",
      "Collecting babel>=2.10 (from jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ed/20/bc79bc575ba2e2a7f70e8a1155618bb1301eaa5132a8271373a6903f73f8/babel-2.16.0-py3-none-any.whl (9.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.6/9.6 MB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting json5>=0.9.0 (from jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/aa/42/797895b952b682c3dafe23b1834507ee7f02f4d6299b65aaa61425763278/json5-0.10.0-py3-none-any.whl (34 kB)\n",
      "Collecting jsonschema>=4.18.0 (from jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl (88 kB)\n",
      "Collecting requests>=2.31 (from jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl (64 kB)\n",
      "Collecting fastjsonschema>=2.15 (from nbformat>=5.7->nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/3f/3a/404a60bb9789ce4daecbb4ec780bee1c46d2ea5258cf689b7ab63acefd6f/fastjsonschema-2.21.0-py3-none-any.whl (23 kB)\n",
      "Collecting soupsieve>1.2 (from beautifulsoup4->nbconvert>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl (36 kB)\n",
      "Collecting idna>=2.8 (from anyio>=3.1.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl (70 kB)\n",
      "Collecting sniffio>=1.1 (from anyio>=3.1.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl (10 kB)\n",
      "Collecting argon2-cffi-bindings (from argon2-cffi>=21.1->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ec/f7/378254e6dd7ae6f31fe40c8649eea7d4832a42243acaf0f1fff9083b2bed/argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (86 kB)\n",
      "Collecting certifi (from httpx>=0.25.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl (167 kB)\n",
      "Collecting httpcore==1.* (from httpx>=0.25.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl (78 kB)\n",
      "Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx>=0.25.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl (58 kB)\n",
      "Requirement already satisfied: comm>=0.1.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.2.2)\n",
      "Requirement already satisfied: debugpy>=1.6.5 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (1.8.9)\n",
      "Requirement already satisfied: ipython>=7.23.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (8.30.0)\n",
      "Requirement already satisfied: matplotlib-inline>=0.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.1.7)\n",
      "Requirement already satisfied: nest-asyncio in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (1.6.0)\n",
      "Requirement already satisfied: psutil in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (6.1.0)\n",
      "Collecting attrs>=22.2.0 (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/6a/21/5b6702a7f963e95456c0de2d495f67bf5fd62840ac655dc451586d23d39a/attrs-24.2.0-py3-none-any.whl (63 kB)\n",
      "Collecting jsonschema-specifications>=2023.03.6 (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl (18 kB)\n",
      "Collecting referencing>=0.28.4 (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl (26 kB)\n",
      "Collecting rpds-py>=0.7.1 (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e1/fd/f1fd7e77fef8e5a442ce7fd80ba957730877515fe18d7195f646408a60ce/rpds_py-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (360 kB)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in ./.conda/lib/python3.11/site-packages (from jupyter-client>=7.4.4->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions) (2.9.0.post0)\n",
      "Collecting python-json-logger>=2.0.4 (from jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/35/a6/145655273568ee78a581e734cf35beb9e33a370b29c5d3c8fee3744de29f/python_json_logger-2.0.7-py3-none-any.whl (8.1 kB)\n",
      "Collecting rfc3339-validator (from jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/7b/44/4e421b96b67b2daff264473f7465db72fbdf36a07e05494f50300cc7b0c6/rfc3339_validator-0.1.4-py2.py3-none-any.whl (3.5 kB)\n",
      "Collecting rfc3986-validator>=0.1.1 (from jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl (4.2 kB)\n",
      "Collecting charset-normalizer<4,>=2 (from requests>=2.31->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/eb/5b/6f10bad0f6461fa272bfbbdf5d0023b5fb9bc6217c92bf068fa5a99820f5/charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)\n",
      "Collecting urllib3<3,>=1.21.1 (from requests>=2.31->jupyterlab-server<3,>=2.27.1->notebook>=6.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl (126 kB)\n",
      "Requirement already satisfied: ptyprocess in ./.conda/lib/python3.11/site-packages (from terminado>=0.8.3->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions) (0.7.0)\n",
      "Requirement already satisfied: decorator in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (5.1.1)\n",
      "Requirement already satisfied: jedi>=0.16 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.19.2)\n",
      "Requirement already satisfied: pexpect>4.3 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (4.9.0)\n",
      "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (3.0.48)\n",
      "Requirement already satisfied: stack_data in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.6.2)\n",
      "Requirement already satisfied: typing_extensions>=4.6 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (4.12.2)\n",
      "Collecting fqdn (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl (9.1 kB)\n",
      "Collecting isoduration (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/7b/55/e5326141505c5d5e34c5e0935d2908a74e4561eca44108fbfb9c13d2911a/isoduration-20.11.0-py3-none-any.whl (11 kB)\n",
      "Collecting jsonpointer>1.13 (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl (7.6 kB)\n",
      "Collecting uri-template (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e7/00/3fca040d7cf8a32776d3d81a00c8ee7457e00f80c649f1e4a863c8321ae9/uri_template-1.3.0-py3-none-any.whl (11 kB)\n",
      "Collecting webcolors>=24.6.0 (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/60/e8/c0e05e4684d13459f93d312077a9a2efbe04d59c393bc2b8802248c908d4/webcolors-24.11.1-py3-none-any.whl (14 kB)\n",
      "Requirement already satisfied: six>=1.5 in ./.conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->jupyter-client>=7.4.4->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions) (1.16.0)\n",
      "Collecting cffi>=1.0.1 (from argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (467 kB)\n",
      "Collecting pycparser (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl (117 kB)\n",
      "Requirement already satisfied: parso<0.9.0,>=0.8.4 in ./.conda/lib/python3.11/site-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.8.4)\n",
      "Requirement already satisfied: wcwidth in ./.conda/lib/python3.11/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.2.13)\n",
      "Collecting arrow>=0.15.0 (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/f8/ed/e97229a566617f2ae958a6b13e7cc0f585470eac730a73e9e82c32a3cdd2/arrow-1.3.0-py3-none-any.whl (66 kB)\n",
      "Requirement already satisfied: executing>=1.2.0 in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (2.1.0)\n",
      "Requirement already satisfied: asttokens>=2.1.0 in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (2.4.1)\n",
      "Requirement already satisfied: pure-eval in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab<4.3,>=4.2.0->notebook>=6.0->jupyter_contrib_nbextensions) (0.2.3)\n",
      "Collecting types-python-dateutil>=2.8.10 (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server->jupyter_nbextensions_configurator>=0.4.0->jupyter_contrib_nbextensions)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/35/d6/ba5f61958f358028f2e2ba1b8e225b8e263053bd57d3a79e2d2db64c807b/types_python_dateutil-2.9.0.20241003-py3-none-any.whl (9.7 kB)\n",
      "Building wheels for collected packages: jupyter_contrib_nbextensions, jupyter_contrib_core\n",
      "  Building wheel for jupyter_contrib_nbextensions (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for jupyter_contrib_nbextensions: filename=jupyter_contrib_nbextensions-0.7.0-py2.py3-none-any.whl size=23428768 sha256=81c0bd79913232ff554358ecea32e5802c09481cf8a26909f240788c41953a84\n",
      "  Stored in directory: /home/ubuntu/.cache/pip/wheels/24/b0/3a/4748a7f09b38540647f3e11218658d5d584623fb0c02eb2e0f\n",
      "  Building wheel for jupyter_contrib_core (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for jupyter_contrib_core: filename=jupyter_contrib_core-0.4.2-py2.py3-none-any.whl size=17475 sha256=f024cdb7b3b6f83017dd5b5c3ca19677f3583e35bd2a74ce56a50dbf17aba4cf\n",
      "  Stored in directory: /home/ubuntu/.cache/pip/wheels/65/0a/f2/aeda1ac92a0a63c7baad7d18735c8e26c513c102d543d2d8d3\n",
      "Successfully built jupyter_contrib_nbextensions jupyter_contrib_core\n",
      "Installing collected packages: webencodings, jupyter_highlight_selected_word, ipython_genutils, fastjsonschema, websocket-client, webcolors, urllib3, uri-template, types-python-dateutil, tinycss2, terminado, soupsieve, sniffio, send2trash, rpds-py, rfc3986-validator, rfc3339-validator, pyyaml, python-json-logger, pycparser, prometheus-client, pandocfilters, overrides, mistune, lxml, jupyterlab-pygments, jsonpointer, json5, idna, h11, fqdn, defusedxml, charset-normalizer, certifi, bleach, babel, attrs, async-lru, requests, referencing, jupyter-server-terminals, httpcore, cffi, beautifulsoup4, arrow, anyio, jsonschema-specifications, isoduration, httpx, argon2-cffi-bindings, jsonschema, argon2-cffi, nbformat, nbclient, jupyter-events, nbconvert, jupyter-server, notebook-shim, jupyterlab-server, jupyter-lsp, jupyterlab, notebook, jupyter_contrib_core, jupyter_nbextensions_configurator, jupyter_contrib_nbextensions\n",
      "Successfully installed anyio-4.6.2.post1 argon2-cffi-23.1.0 argon2-cffi-bindings-21.2.0 arrow-1.3.0 async-lru-2.0.4 attrs-24.2.0 babel-2.16.0 beautifulsoup4-4.12.3 bleach-6.2.0 certifi-2024.8.30 cffi-1.17.1 charset-normalizer-3.4.0 defusedxml-0.7.1 fastjsonschema-2.21.0 fqdn-1.5.1 h11-0.14.0 httpcore-1.0.7 httpx-0.28.0 idna-3.10 ipython_genutils-0.2.0 isoduration-20.11.0 json5-0.10.0 jsonpointer-3.0.0 jsonschema-4.23.0 jsonschema-specifications-2024.10.1 jupyter-events-0.10.0 jupyter-lsp-2.2.5 jupyter-server-2.14.2 jupyter-server-terminals-0.5.3 jupyter_contrib_core-0.4.2 jupyter_contrib_nbextensions-0.7.0 jupyter_highlight_selected_word-0.2.0 jupyter_nbextensions_configurator-0.6.4 jupyterlab-4.2.6 jupyterlab-pygments-0.3.0 jupyterlab-server-2.27.3 lxml-5.3.0 mistune-3.0.2 nbclient-0.10.1 nbconvert-7.16.4 nbformat-5.10.4 notebook-7.2.2 notebook-shim-0.2.4 overrides-7.7.0 pandocfilters-1.5.1 prometheus-client-0.21.0 pycparser-2.22 python-json-logger-2.0.7 pyyaml-6.0.2 referencing-0.35.1 requests-2.32.3 rfc3339-validator-0.1.4 rfc3986-validator-0.1.1 rpds-py-0.21.0 send2trash-1.8.3 sniffio-1.3.1 soupsieve-2.6 terminado-0.18.1 tinycss2-1.4.0 types-python-dateutil-2.9.0.20241003 uri-template-1.3.0 urllib3-2.2.3 webcolors-24.11.1 webencodings-0.5.1 websocket-client-1.8.0\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Requirement already satisfied: jupyterlab in ./.conda/lib/python3.11/site-packages (4.2.6)\n",
      "Requirement already satisfied: async-lru>=1.0.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (2.0.4)\n",
      "Requirement already satisfied: httpx>=0.25.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (0.28.0)\n",
      "Requirement already satisfied: ipykernel>=6.5.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (6.29.5)\n",
      "Requirement already satisfied: jinja2>=3.0.3 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (3.1.3)\n",
      "Requirement already satisfied: jupyter-core in ./.conda/lib/python3.11/site-packages (from jupyterlab) (5.7.2)\n",
      "Requirement already satisfied: jupyter-lsp>=2.0.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (2.2.5)\n",
      "Requirement already satisfied: jupyter-server<3,>=2.4.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (2.14.2)\n",
      "Requirement already satisfied: jupyterlab-server<3,>=2.27.1 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (2.27.3)\n",
      "Requirement already satisfied: notebook-shim>=0.2 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (0.2.4)\n",
      "Requirement already satisfied: packaging in ./.conda/lib/python3.11/site-packages (from jupyterlab) (24.2)\n",
      "Requirement already satisfied: setuptools>=40.1.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (75.1.0)\n",
      "Requirement already satisfied: tornado>=6.2.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab) (6.4.2)\n",
      "Requirement already satisfied: traitlets in ./.conda/lib/python3.11/site-packages (from jupyterlab) (5.14.3)\n",
      "Requirement already satisfied: anyio in ./.conda/lib/python3.11/site-packages (from httpx>=0.25.0->jupyterlab) (4.6.2.post1)\n",
      "Requirement already satisfied: certifi in ./.conda/lib/python3.11/site-packages (from httpx>=0.25.0->jupyterlab) (2024.8.30)\n",
      "Requirement already satisfied: httpcore==1.* in ./.conda/lib/python3.11/site-packages (from httpx>=0.25.0->jupyterlab) (1.0.7)\n",
      "Requirement already satisfied: idna in ./.conda/lib/python3.11/site-packages (from httpx>=0.25.0->jupyterlab) (3.10)\n",
      "Requirement already satisfied: h11<0.15,>=0.13 in ./.conda/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.25.0->jupyterlab) (0.14.0)\n",
      "Requirement already satisfied: comm>=0.1.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (0.2.2)\n",
      "Requirement already satisfied: debugpy>=1.6.5 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (1.8.9)\n",
      "Requirement already satisfied: ipython>=7.23.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (8.30.0)\n",
      "Requirement already satisfied: jupyter-client>=6.1.12 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (8.6.3)\n",
      "Requirement already satisfied: matplotlib-inline>=0.1 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (0.1.7)\n",
      "Requirement already satisfied: nest-asyncio in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (1.6.0)\n",
      "Requirement already satisfied: psutil in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (6.1.0)\n",
      "Requirement already satisfied: pyzmq>=24 in ./.conda/lib/python3.11/site-packages (from ipykernel>=6.5.0->jupyterlab) (26.2.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in ./.conda/lib/python3.11/site-packages (from jinja2>=3.0.3->jupyterlab) (2.1.5)\n",
      "Requirement already satisfied: platformdirs>=2.5 in ./.conda/lib/python3.11/site-packages (from jupyter-core->jupyterlab) (4.3.6)\n",
      "Requirement already satisfied: argon2-cffi>=21.1 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (23.1.0)\n",
      "Requirement already satisfied: jupyter-events>=0.9.0 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (0.10.0)\n",
      "Requirement already satisfied: jupyter-server-terminals>=0.4.4 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (0.5.3)\n",
      "Requirement already satisfied: nbconvert>=6.4.4 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (7.16.4)\n",
      "Requirement already satisfied: nbformat>=5.3.0 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (5.10.4)\n",
      "Requirement already satisfied: overrides>=5.0 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (7.7.0)\n",
      "Requirement already satisfied: prometheus-client>=0.9 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (0.21.0)\n",
      "Requirement already satisfied: send2trash>=1.8.2 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (1.8.3)\n",
      "Requirement already satisfied: terminado>=0.8.3 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (0.18.1)\n",
      "Requirement already satisfied: websocket-client>=1.7 in ./.conda/lib/python3.11/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab) (1.8.0)\n",
      "Requirement already satisfied: babel>=2.10 in ./.conda/lib/python3.11/site-packages (from jupyterlab-server<3,>=2.27.1->jupyterlab) (2.16.0)\n",
      "Requirement already satisfied: json5>=0.9.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab-server<3,>=2.27.1->jupyterlab) (0.10.0)\n",
      "Requirement already satisfied: jsonschema>=4.18.0 in ./.conda/lib/python3.11/site-packages (from jupyterlab-server<3,>=2.27.1->jupyterlab) (4.23.0)\n",
      "Requirement already satisfied: requests>=2.31 in ./.conda/lib/python3.11/site-packages (from jupyterlab-server<3,>=2.27.1->jupyterlab) (2.32.3)\n",
      "Requirement already satisfied: sniffio>=1.1 in ./.conda/lib/python3.11/site-packages (from anyio->httpx>=0.25.0->jupyterlab) (1.3.1)\n",
      "Requirement already satisfied: argon2-cffi-bindings in ./.conda/lib/python3.11/site-packages (from argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab) (21.2.0)\n",
      "Requirement already satisfied: decorator in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (5.1.1)\n",
      "Requirement already satisfied: jedi>=0.16 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (0.19.2)\n",
      "Requirement already satisfied: pexpect>4.3 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (4.9.0)\n",
      "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (3.0.48)\n",
      "Requirement already satisfied: pygments>=2.4.0 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (2.18.0)\n",
      "Requirement already satisfied: stack_data in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (0.6.2)\n",
      "Requirement already satisfied: typing_extensions>=4.6 in ./.conda/lib/python3.11/site-packages (from ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (4.12.2)\n",
      "Requirement already satisfied: attrs>=22.2.0 in ./.conda/lib/python3.11/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->jupyterlab) (24.2.0)\n",
      "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in ./.conda/lib/python3.11/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->jupyterlab) (2024.10.1)\n",
      "Requirement already satisfied: referencing>=0.28.4 in ./.conda/lib/python3.11/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->jupyterlab) (0.35.1)\n",
      "Requirement already satisfied: rpds-py>=0.7.1 in ./.conda/lib/python3.11/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.27.1->jupyterlab) (0.21.0)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in ./.conda/lib/python3.11/site-packages (from jupyter-client>=6.1.12->ipykernel>=6.5.0->jupyterlab) (2.9.0.post0)\n",
      "Requirement already satisfied: python-json-logger>=2.0.4 in ./.conda/lib/python3.11/site-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (2.0.7)\n",
      "Requirement already satisfied: pyyaml>=5.3 in ./.conda/lib/python3.11/site-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (6.0.2)\n",
      "Requirement already satisfied: rfc3339-validator in ./.conda/lib/python3.11/site-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (0.1.4)\n",
      "Requirement already satisfied: rfc3986-validator>=0.1.1 in ./.conda/lib/python3.11/site-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (0.1.1)\n",
      "Requirement already satisfied: beautifulsoup4 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (4.12.3)\n",
      "Requirement already satisfied: bleach!=5.0.0 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (6.2.0)\n",
      "Requirement already satisfied: defusedxml in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (0.7.1)\n",
      "Requirement already satisfied: jupyterlab-pygments in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (0.3.0)\n",
      "Requirement already satisfied: mistune<4,>=2.0.3 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (3.0.2)\n",
      "Requirement already satisfied: nbclient>=0.5.0 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (0.10.1)\n",
      "Requirement already satisfied: pandocfilters>=1.4.1 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (1.5.1)\n",
      "Requirement already satisfied: tinycss2 in ./.conda/lib/python3.11/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (1.4.0)\n",
      "Requirement already satisfied: fastjsonschema>=2.15 in ./.conda/lib/python3.11/site-packages (from nbformat>=5.3.0->jupyter-server<3,>=2.4.0->jupyterlab) (2.21.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in ./.conda/lib/python3.11/site-packages (from requests>=2.31->jupyterlab-server<3,>=2.27.1->jupyterlab) (3.4.0)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in ./.conda/lib/python3.11/site-packages (from requests>=2.31->jupyterlab-server<3,>=2.27.1->jupyterlab) (2.2.3)\n",
      "Requirement already satisfied: ptyprocess in ./.conda/lib/python3.11/site-packages (from terminado>=0.8.3->jupyter-server<3,>=2.4.0->jupyterlab) (0.7.0)\n",
      "Requirement already satisfied: webencodings in ./.conda/lib/python3.11/site-packages (from bleach!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (0.5.1)\n",
      "Requirement already satisfied: parso<0.9.0,>=0.8.4 in ./.conda/lib/python3.11/site-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (0.8.4)\n",
      "Requirement already satisfied: fqdn in ./.conda/lib/python3.11/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (1.5.1)\n",
      "Requirement already satisfied: isoduration in ./.conda/lib/python3.11/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (20.11.0)\n",
      "Requirement already satisfied: jsonpointer>1.13 in ./.conda/lib/python3.11/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (3.0.0)\n",
      "Requirement already satisfied: uri-template in ./.conda/lib/python3.11/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (1.3.0)\n",
      "Requirement already satisfied: webcolors>=24.6.0 in ./.conda/lib/python3.11/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (24.11.1)\n",
      "Requirement already satisfied: wcwidth in ./.conda/lib/python3.11/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (0.2.13)\n",
      "Requirement already satisfied: six>=1.5 in ./.conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=6.5.0->jupyterlab) (1.16.0)\n",
      "Requirement already satisfied: cffi>=1.0.1 in ./.conda/lib/python3.11/site-packages (from argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab) (1.17.1)\n",
      "Requirement already satisfied: soupsieve>1.2 in ./.conda/lib/python3.11/site-packages (from beautifulsoup4->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab) (2.6)\n",
      "Requirement already satisfied: executing>=1.2.0 in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (2.1.0)\n",
      "Requirement already satisfied: asttokens>=2.1.0 in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (2.4.1)\n",
      "Requirement already satisfied: pure-eval in ./.conda/lib/python3.11/site-packages (from stack_data->ipython>=7.23.1->ipykernel>=6.5.0->jupyterlab) (0.2.3)\n",
      "Requirement already satisfied: pycparser in ./.conda/lib/python3.11/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab) (2.22)\n",
      "Requirement already satisfied: arrow>=0.15.0 in ./.conda/lib/python3.11/site-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (1.3.0)\n",
      "Requirement already satisfied: types-python-dateutil>=2.8.10 in ./.conda/lib/python3.11/site-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->jupyterlab) (2.9.0.20241003)\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting opencv-python\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/3f/a4/d2537f47fd7fcfba966bd806e3ec18e7ee1681056d4b0a9c8d983983e4d5/opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (62.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.5/62.5 MB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.21.2 in ./.conda/lib/python3.11/site-packages (from opencv-python) (2.1.3)\n",
      "Installing collected packages: opencv-python\n",
      "Successfully installed opencv-python-4.10.0.84\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "Collecting matplotlib\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/13/53/b178d51478109f7a700edc94757dd07112e9a0c7a158653b99434b74f9fb/matplotlib-3.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.3/8.3 MB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting contourpy>=1.0.1 (from matplotlib)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/85/fc/7fa5d17daf77306840a4e84668a48ddff09e6bc09ba4e37e85ffc8e4faa3/contourpy-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (326 kB)\n",
      "Collecting cycler>=0.10 (from matplotlib)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
      "Collecting fonttools>=4.22.0 (from matplotlib)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/47/2b/9bf7527260d265281dd812951aa22f3d1c331bcc91e86e7038dc6b9737cb/fonttools-4.55.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/a7/4b/2db7af3ed3af7c35f388d5f53c28e155cd402a55432d800c543dc6deb731/kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m-:--:--\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.23 in ./.conda/lib/python3.11/site-packages (from matplotlib) (2.1.3)\n",
      "Requirement already satisfied: packaging>=20.0 in ./.conda/lib/python3.11/site-packages (from matplotlib) (24.2)\n",
      "Requirement already satisfied: pillow>=8 in ./.conda/lib/python3.11/site-packages (from matplotlib) (10.2.0)\n",
      "Collecting pyparsing>=2.3.1 (from matplotlib)\n",
      "  Downloading http://mirrors.tencentyun.com/pypi/packages/be/ec/2eb3cd785efd67806c46c13a17339708ddc346cbb684eade7a6e6f79536a/pyparsing-3.2.0-py3-none-any.whl (106 kB)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in ./.conda/lib/python3.11/site-packages (from matplotlib) (2.9.0.post0)\n",
      "Requirement already satisfied: six>=1.5 in ./.conda/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
      "Installing collected packages: pyparsing, kiwisolver, fonttools, cycler, contourpy, matplotlib\n",
      "Successfully installed contourpy-1.3.1 cycler-0.12.1 fonttools-4.55.0 kiwisolver-1.4.7 matplotlib-3.9.3 pyparsing-3.2.0\n",
      "Looking in indexes: http://mirrors.tencentyun.com/pypi/simple\n",
      "\u001b[31mERROR: Could not find a version that satisfies the requirement moxing (from versions: none)\u001b[0m\u001b[31m\n",
      "\u001b[0m\u001b[31mERROR: No matching distribution found for moxing\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install jupyter_contrib_nbextensions\n",
    "!pip install jupyterlab\n",
    "!pip install opencv-python\n",
    "!pip install matplotlib\n",
    "!pip install moxing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import subprocess as sp \n",
    "from collections import deque\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.multiprocessing as _mp\n",
    "from torch.distributions import Categorical\n",
    "import torch.multiprocessing as mp\n",
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym.spaces import Box\n",
    "from gym import Wrapper\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "\n",
    "# import moxing as mox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt={\n",
    "    \"world\": 1,                # 可选大关：1,2,3,4,5,6,7,8\n",
    "    \"stage\": 1,                # 可选小关：1,2,3,4 \n",
    "    \"action_type\": \"simple\",   # 动作类别：\"simple\"，\"right_only\", \"complex\"\n",
    "    'lr': 1e-4,                # 建议学习率：1e-3，1e-4, 1e-5，7e-5\n",
    "    'gamma': 0.9,              # 奖励折扣\n",
    "    'tau': 1.0,                # GAE参数\n",
    "    'beta': 0.01,              # 熵系数\n",
    "    'epsilon': 0.2,            # PPO的Clip系数\n",
    "    'batch_size': 16,          # 经验回放的batch_size\n",
    "    'max_episode':10,          # 最大训练局数\n",
    "    'num_epochs': 10,          # 每条经验回放次数\n",
    "    \"num_local_steps\": 512,    # 每局的最大步数\n",
    "    \"num_processes\": 8,        # 训练进程数，一般等于训练机核心数\n",
    "    \"save_interval\": 5,        # 每{}局保存一次模型\n",
    "    \"log_path\": \"./log\",       # 日志保存路径\n",
    "    \"saved_path\": \"./model\",   # 训练模型保存路径\n",
    "    \"pretrain_model\": True,    # 是否加载预训练模型，目前只提供1-1关卡的预训练模型，其他需要从零开始训练\n",
    "    \"episode\":5\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建环境\n",
    "def create_train_env(world, stage, actions, output_path=None):\n",
    "    # 创建基础环境\n",
    "    env = gym_super_mario_bros.make(\"SuperMarioBros-{}-{}-v0\".format(world, stage))\n",
    "\n",
    "    env = JoypadSpace(env, actions)\n",
    "    # 对环境自定义\n",
    "    env = CustomReward(env, world, stage, monitor=None)\n",
    "    env = CustomSkipFrame(env)\n",
    "    return env\n",
    "# 对原始环境进行修改，以获得更好的训练效果\n",
    "class CustomReward(Wrapper):\n",
    "    def __init__(self, env=None, world=None, stage=None, monitor=None):\n",
    "        super(CustomReward, self).__init__(env)\n",
    "        self.observation_space = Box(low=0, high=255, shape=(1, 84, 84))\n",
    "        self.curr_score = 0\n",
    "        self.current_x = 40\n",
    "        self.world = world\n",
    "        self.stage = stage\n",
    "        if monitor:\n",
    "            self.monitor = monitor\n",
    "        else:\n",
    "            self.monitor = None\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, info = self.env.step(action)\n",
    "        if self.monitor:\n",
    "            self.monitor.record(state)\n",
    "        state = process_frame(state)\n",
    "        reward += (info[\"score\"] - self.curr_score) / 40.\n",
    "        self.curr_score = info[\"score\"]\n",
    "        if done:\n",
    "            if info[\"flag_get\"]:\n",
    "                reward += 50\n",
    "            else:\n",
    "                reward -= 50\n",
    "        if self.world == 7 and self.stage == 4:\n",
    "            if (506 <= info[\"x_pos\"] <= 832 and info[\"y_pos\"] > 127) or (\n",
    "                    832 < info[\"x_pos\"] <= 1064 and info[\"y_pos\"] < 80) or (\n",
    "                    1113 < info[\"x_pos\"] <= 1464 and info[\"y_pos\"] < 191) or (\n",
    "                    1579 < info[\"x_pos\"] <= 1943 and info[\"y_pos\"] < 191) or (\n",
    "                    1946 < info[\"x_pos\"] <= 1964 and info[\"y_pos\"] >= 191) or (\n",
    "                    1984 < info[\"x_pos\"] <= 2060 and (info[\"y_pos\"] >= 191 or info[\"y_pos\"] < 127)) or (\n",
    "                    2114 < info[\"x_pos\"] < 2440 and info[\"y_pos\"] < 191) or info[\"x_pos\"] < self.current_x - 500:\n",
    "                reward -= 50\n",
    "                done = True\n",
    "        if self.world == 4 and self.stage == 4:\n",
    "            if (info[\"x_pos\"] <= 1500 and info[\"y_pos\"] < 127) or (\n",
    "                    1588 <= info[\"x_pos\"] < 2380 and info[\"y_pos\"] >= 127):\n",
    "                reward = -50\n",
    "                done = True\n",
    "\n",
    "        self.current_x = info[\"x_pos\"]\n",
    "        return state, reward / 10., done, info\n",
    "\n",
    "    def reset(self):\n",
    "        self.curr_score = 0\n",
    "        self.current_x = 40\n",
    "        return process_frame(self.env.reset())\n",
    "class MultipleEnvironments:\n",
    "    def __init__(self, world, stage, action_type, num_envs, output_path=None):\n",
    "        self.agent_conns, self.env_conns = zip(*[mp.Pipe() for _ in range(num_envs)])\n",
    "        if action_type == \"right_only\":\n",
    "            actions = RIGHT_ONLY\n",
    "        elif action_type == \"simple\":\n",
    "            actions = SIMPLE_MOVEMENT\n",
    "        else:\n",
    "            actions = COMPLEX_MOVEMENT\n",
    "        self.envs = [create_train_env(world, stage, actions, output_path=output_path) for _ in range(num_envs)]\n",
    "        self.num_states = self.envs[0].observation_space.shape[0]\n",
    "        self.num_actions = len(actions)\n",
    "        for index in range(num_envs):\n",
    "            process = mp.Process(target=self.run, args=(index,))\n",
    "            process.start()\n",
    "            self.env_conns[index].close()\n",
    "\n",
    "    def run(self, index):\n",
    "        self.agent_conns[index].close()\n",
    "        while True:\n",
    "            request, action = self.env_conns[index].recv()\n",
    "            if request == \"step\":\n",
    "                self.env_conns[index].send(self.envs[index].step(action.item()))\n",
    "            elif request == \"reset\":\n",
    "                self.env_conns[index].send(self.envs[index].reset())\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "def process_frame(frame):\n",
    "    if frame is not None:\n",
    "        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)\n",
    "        frame = cv2.resize(frame, (84, 84))[None, :, :] / 255.\n",
    "        return frame\n",
    "    else:\n",
    "        return np.zeros((1, 84, 84))\n",
    "class CustomSkipFrame(Wrapper):\n",
    "    def __init__(self, env, skip=4):\n",
    "        super(CustomSkipFrame, self).__init__(env)\n",
    "        self.observation_space = Box(low=0, high=255, shape=(skip, 84, 84))\n",
    "        self.skip = skip\n",
    "        self.states = np.zeros((skip, 84, 84), dtype=np.float32)\n",
    "\n",
    "    def step(self, action):\n",
    "        total_reward = 0\n",
    "        last_states = []\n",
    "        for i in range(self.skip):\n",
    "            state, reward, done, info = self.env.step(action)\n",
    "            total_reward += reward\n",
    "            if i >= self.skip / 2:\n",
    "                last_states.append(state)\n",
    "            if done:\n",
    "                self.reset()\n",
    "                return self.states[None, :, :, :].astype(np.float32), total_reward, done, info\n",
    "        max_state = np.max(np.concatenate(last_states, 0), 0)\n",
    "        self.states[:-1] = self.states[1:]\n",
    "        self.states[-1] = max_state\n",
    "        return self.states[None, :, :, :].astype(np.float32), total_reward, done, info\n",
    "\n",
    "    def reset(self):\n",
    "        state = self.env.reset()\n",
    "        self.states = np.concatenate([state for _ in range(self.skip)], 0)\n",
    "        return self.states[None, :, :, :].astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, num_inputs, num_actions):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)\n",
    "        self.linear = nn.Linear(32 * 6 * 6, 512)\n",
    "        self.critic_linear = nn.Linear(512, 1)\n",
    "        self.actor_linear = nn.Linear(512, num_actions)\n",
    "        self._initialize_weights()\n",
    "\n",
    "    def _initialize_weights(self):\n",
    "        for module in self.modules():\n",
    "            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):\n",
    "                nn.init.orthogonal_(module.weight, nn.init.calculate_gain('relu'))\n",
    "                nn.init.constant_(module.bias, 0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = F.relu(self.conv4(x))\n",
    "        x = self.linear(x.view(x.size(0), -1))\n",
    "        return self.actor_linear(x), self.critic_linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(opt, global_model, num_states, num_actions,curr_episode):\n",
    "    print('start evalution !')\n",
    "    torch.manual_seed(123)\n",
    "    if opt['action_type'] == \"right\":\n",
    "        actions = RIGHT_ONLY\n",
    "    elif opt['action_type'] == \"simple\":\n",
    "        actions = SIMPLE_MOVEMENT\n",
    "    else:\n",
    "        actions = COMPLEX_MOVEMENT\n",
    "    env = create_train_env(opt['world'], opt['stage'], actions)\n",
    "    local_model = Net(num_states, num_actions)\n",
    "    if torch.cuda.is_available():\n",
    "        local_model.cuda()\n",
    "    local_model.eval()\n",
    "    state = torch.from_numpy(env.reset())\n",
    "    if torch.cuda.is_available():\n",
    "        state = state.cuda()\n",
    "    \n",
    "    plt.figure(figsize=(10,10))\n",
    "    img = plt.imshow(env.render(mode='rgb_array'))\n",
    "    \n",
    "    done=False\n",
    "    local_model.load_state_dict(global_model.state_dict()) #加载网络参数\\\n",
    "\n",
    "    while not done:\n",
    "        if torch.cuda.is_available():\n",
    "            state = state.cuda()\n",
    "        logits, value = local_model(state)\n",
    "        policy = F.softmax(logits, dim=1)\n",
    "        action = torch.argmax(policy).item()\n",
    "        state, reward, done, info = env.step(action)\n",
    "        state = torch.from_numpy(state)\n",
    "        \n",
    "        img.set_data(env.render(mode='rgb_array')) # just update the data\n",
    "        display.display(plt.gcf())\n",
    "        display.clear_output(wait=True)\n",
    "\n",
    "        if info[\"flag_get\"]:\n",
    "            print(\"flag getted in episode:{}!\".format(curr_episode))\n",
    "            torch.save(local_model.state_dict(),\n",
    "                       \"{}/ppo_super_mario_bros_{}_{}_{}\".format(opt['saved_path'], opt['world'], opt['stage'],curr_episode))\n",
    "            opt.update({'episode':curr_episode})\n",
    "            env.close()\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "    \n",
    "def train(opt):\n",
    "    #判断cuda是否可用\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(123)\n",
    "    else:\n",
    "        torch.manual_seed(123)\n",
    "    if os.path.isdir(opt['log_path']):\n",
    "        shutil.rmtree(opt['log_path'])\n",
    "\n",
    "    os.makedirs(opt['log_path'])\n",
    "    if not os.path.isdir(opt['saved_path']):\n",
    "        os.makedirs(opt['saved_path'])\n",
    "    mp = _mp.get_context(\"spawn\")\n",
    "    #创建环境\n",
    "    envs = MultipleEnvironments(opt['world'], opt['stage'], opt['action_type'], opt['num_processes'])\n",
    "    #创建模型\n",
    "    model = Net(envs.num_states, envs.num_actions)\n",
    "    if opt['pretrain_model']:\n",
    "        print('加载预训练模型')\n",
    "        if not os.path.exists(\"ppo_super_mario_bros_1_1_0\"):\n",
    "            mox.file.copy_parallel(\n",
    "                \"obs://modelarts-labs-bj4/course/modelarts/zjc_team/reinforcement_learning/ppo_mario/ppo_super_mario_bros_1_1_0\",\n",
    "                \"ppo_super_mario_bros_1_1_0\")\n",
    "        if torch.cuda.is_available():\n",
    "            model.load_state_dict(torch.load(\"ppo_super_mario_bros_1_1_0\"))\n",
    "            model.cuda()\n",
    "        else:\n",
    "            model.load_state_dict(torch.load(\"ppo_super_mario_bros_1_1_0\",torch.device('cpu')))\n",
    "    else:\n",
    "         model.cuda()\n",
    "    model.share_memory()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=opt['lr'])\n",
    "    #环境重置\n",
    "    [agent_conn.send((\"reset\", None)) for agent_conn in envs.agent_conns]\n",
    "    #接收当前状态\n",
    "    curr_states = [agent_conn.recv() for agent_conn in envs.agent_conns]\n",
    "    curr_states = torch.from_numpy(np.concatenate(curr_states, 0))\n",
    "    if torch.cuda.is_available():\n",
    "        curr_states = curr_states.cuda()\n",
    "    curr_episode = 0\n",
    "    #在最大局数内训练\n",
    "    while curr_episode<opt['max_episode']:\n",
    "        if curr_episode % opt['save_interval'] == 0 and curr_episode > 0:\n",
    "            torch.save(model.state_dict(),\n",
    "                       \"{}/ppo_super_mario_bros_{}_{}_{}\".format(opt['saved_path'], opt['world'], opt['stage'], curr_episode))\n",
    "        curr_episode += 1\n",
    "        old_log_policies = []\n",
    "        actions = []\n",
    "        values = []\n",
    "        states = []\n",
    "        rewards = []\n",
    "        dones = []\n",
    "        #一局内最大步数\n",
    "        for _ in range(opt['num_local_steps']):\n",
    "            states.append(curr_states)\n",
    "            logits, value = model(curr_states)\n",
    "            values.append(value.squeeze())\n",
    "            policy = F.softmax(logits, dim=1)\n",
    "            old_m = Categorical(policy)\n",
    "            action = old_m.sample()\n",
    "            actions.append(action)\n",
    "            old_log_policy = old_m.log_prob(action)\n",
    "            old_log_policies.append(old_log_policy)\n",
    "            #执行action\n",
    "            if torch.cuda.is_available():\n",
    "                [agent_conn.send((\"step\", act)) for agent_conn, act in zip(envs.agent_conns, action.cpu())]\n",
    "            else:\n",
    "                [agent_conn.send((\"step\", act)) for agent_conn, act in zip(envs.agent_conns, action)]\n",
    "            state, reward, done, info = zip(*[agent_conn.recv() for agent_conn in envs.agent_conns])\n",
    "            state = torch.from_numpy(np.concatenate(state, 0))\n",
    "            if torch.cuda.is_available():\n",
    "                state = state.cuda()\n",
    "                reward = torch.cuda.FloatTensor(reward)\n",
    "                done = torch.cuda.FloatTensor(done)\n",
    "            else:\n",
    "                reward = torch.FloatTensor(reward)\n",
    "                done = torch.FloatTensor(done)\n",
    "            rewards.append(reward)\n",
    "            dones.append(done)\n",
    "            curr_states = state\n",
    "\n",
    "        _, next_value, = model(curr_states)\n",
    "        next_value = next_value.squeeze()\n",
    "        old_log_policies = torch.cat(old_log_policies).detach()\n",
    "        actions = torch.cat(actions)\n",
    "        values = torch.cat(values).detach()\n",
    "        states = torch.cat(states)\n",
    "        gae = 0\n",
    "        R = []\n",
    "        #gae计算\n",
    "        for value, reward, done in list(zip(values, rewards, dones))[::-1]:\n",
    "            gae = gae * opt['gamma'] * opt['tau']\n",
    "            gae = gae + reward + opt['gamma'] * next_value.detach() * (1 - done) - value.detach()\n",
    "            next_value = value\n",
    "            R.append(gae + value)\n",
    "        R = R[::-1]\n",
    "        R = torch.cat(R).detach()\n",
    "        advantages = R - values\n",
    "        #策略更新\n",
    "        for i in range(opt['num_epochs']):\n",
    "            indice = torch.randperm(opt['num_local_steps'] * opt['num_processes'])\n",
    "            for j in range(opt['batch_size']):\n",
    "                batch_indices = indice[\n",
    "                                int(j * (opt['num_local_steps'] * opt['num_processes'] / opt['batch_size'])): int((j + 1) * (\n",
    "                                        opt['num_local_steps'] * opt['num_processes'] / opt['batch_size']))]\n",
    "                logits, value = model(states[batch_indices])\n",
    "                new_policy = F.softmax(logits, dim=1)\n",
    "                new_m = Categorical(new_policy)\n",
    "                new_log_policy = new_m.log_prob(actions[batch_indices])\n",
    "                ratio = torch.exp(new_log_policy - old_log_policies[batch_indices])\n",
    "                actor_loss = -torch.mean(torch.min(ratio * advantages[batch_indices],\n",
    "                                                   torch.clamp(ratio, 1.0 - opt['epsilon'], 1.0 + opt['epsilon']) *\n",
    "                                                   advantages[\n",
    "                                                       batch_indices]))\n",
    "                critic_loss = F.smooth_l1_loss(R[batch_indices], value.squeeze())\n",
    "                entropy_loss = torch.mean(new_m.entropy())\n",
    "                #损失函数包含三个部分：actor损失，critic损失，和动作entropy损失\n",
    "                total_loss = actor_loss + critic_loss - opt['beta'] * entropy_loss\n",
    "                optimizer.zero_grad()\n",
    "                total_loss.backward()\n",
    "                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
    "                optimizer.step()\n",
    "        print(\"Episode: {}. Total loss: {}\".format(curr_episode, total_loss))\n",
    "        \n",
    "        finish=False\n",
    "        for i in range(opt[\"num_processes\"]):\n",
    "            if info[i][\"flag_get\"]:\n",
    "                finish=evaluation(opt, model,envs.num_states, envs.num_actions,curr_episode)\n",
    "                if finish:\n",
    "                    break\n",
    "        if finish:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "ename": "OverflowError",
     "evalue": "Python integer 1024 out of bounds for uint8",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOverflowError\u001b[0m                             Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[9], line 62\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(opt)\u001b[0m\n\u001b[1;32m     60\u001b[0m mp \u001b[38;5;241m=\u001b[39m _mp\u001b[38;5;241m.\u001b[39mget_context(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mspawn\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     61\u001b[0m \u001b[38;5;66;03m#创建环境\u001b[39;00m\n\u001b[0;32m---> 62\u001b[0m envs \u001b[38;5;241m=\u001b[39m \u001b[43mMultipleEnvironments\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mworld\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mstage\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43maction_type\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnum_processes\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     63\u001b[0m \u001b[38;5;66;03m#创建模型\u001b[39;00m\n\u001b[1;32m     64\u001b[0m model \u001b[38;5;241m=\u001b[39m Net(envs\u001b[38;5;241m.\u001b[39mnum_states, envs\u001b[38;5;241m.\u001b[39mnum_actions)\n",
      "Cell \u001b[0;32mIn[4], line 69\u001b[0m, in \u001b[0;36mMultipleEnvironments.__init__\u001b[0;34m(self, world, stage, action_type, num_envs, output_path)\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     68\u001b[0m     actions \u001b[38;5;241m=\u001b[39m COMPLEX_MOVEMENT\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menvs \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[43mcreate_train_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworld\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mactions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_path\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mnum_envs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m     70\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menvs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mobservation_space\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(actions)\n",
      "Cell \u001b[0;32mIn[4], line 69\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     68\u001b[0m     actions \u001b[38;5;241m=\u001b[39m COMPLEX_MOVEMENT\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menvs \u001b[38;5;241m=\u001b[39m [\u001b[43mcreate_train_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworld\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mactions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_path\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_envs)]\n\u001b[1;32m     70\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menvs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mobservation_space\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(actions)\n",
      "Cell \u001b[0;32mIn[4], line 4\u001b[0m, in \u001b[0;36mcreate_train_env\u001b[0;34m(world, stage, actions, output_path)\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_train_env\u001b[39m(world, stage, actions, output_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m      3\u001b[0m     \u001b[38;5;66;03m# 创建基础环境\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m     env \u001b[38;5;241m=\u001b[39m \u001b[43mgym_super_mario_bros\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSuperMarioBros-\u001b[39;49m\u001b[38;5;132;43;01m{}\u001b[39;49;00m\u001b[38;5;124;43m-\u001b[39;49m\u001b[38;5;132;43;01m{}\u001b[39;49;00m\u001b[38;5;124;43m-v0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworld\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      6\u001b[0m     env \u001b[38;5;241m=\u001b[39m JoypadSpace(env, actions)\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;66;03m# 对环境自定义\u001b[39;00m\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/gym/envs/registration.py:640\u001b[0m, in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m    637\u001b[0m     render_mode \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    639\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 640\u001b[0m     env \u001b[38;5;241m=\u001b[39m \u001b[43menv_creator\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    641\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    642\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m    643\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e)\u001b[38;5;241m.\u001b[39mfind(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgot an unexpected keyword argument \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrender_mode\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    644\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m apply_human_rendering\n\u001b[1;32m    645\u001b[0m     ):\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/gym_super_mario_bros/smb_env.py:52\u001b[0m, in \u001b[0;36mSuperMarioBrosEnv.__init__\u001b[0;34m(self, rom_mode, lost_levels, target)\u001b[0m\n\u001b[1;32m     50\u001b[0m rom \u001b[38;5;241m=\u001b[39m rom_path(lost_levels, rom_mode)\n\u001b[1;32m     51\u001b[0m \u001b[38;5;66;03m# initialize the super object with the ROM path\u001b[39;00m\n\u001b[0;32m---> 52\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mSuperMarioBrosEnv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mrom\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     53\u001b[0m \u001b[38;5;66;03m# set the target world, stage, and area variables\u001b[39;00m\n\u001b[1;32m     54\u001b[0m target \u001b[38;5;241m=\u001b[39m decode_target(target, lost_levels)\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/nes_env.py:126\u001b[0m, in \u001b[0;36mNESEnv.__init__\u001b[0;34m(self, rom_path)\u001b[0m\n\u001b[1;32m    124\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mROM has trainer. trainer is not supported.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    125\u001b[0m \u001b[38;5;66;03m# try to read the PRG ROM and raise a value error if it fails\u001b[39;00m\n\u001b[0;32m--> 126\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mrom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom\u001b[49m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;66;03m# try to read the CHR ROM and raise a value error if it fails\u001b[39;00m\n\u001b[1;32m    128\u001b[0m _ \u001b[38;5;241m=\u001b[39m rom\u001b[38;5;241m.\u001b[39mchr_rom\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/_rom.py:204\u001b[0m, in \u001b[0;36mROM.prg_rom\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    202\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the PRG ROM of the ROM file.\"\"\"\u001b[39;00m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 204\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw_data[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprg_rom_start:\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom_stop\u001b[49m]\n\u001b[1;32m    205\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m:\n\u001b[1;32m    206\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfailed to read PRG-ROM on ROM.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/_rom.py:198\u001b[0m, in \u001b[0;36mROM.prg_rom_stop\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    195\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m    196\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprg_rom_stop\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    197\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"The exclusive stopping index of the PRG ROM.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprg_rom_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\n",
      "\u001b[0;31mOverflowError\u001b[0m: Python integer 1024 out of bounds for uint8"
     ]
    }
   ],
   "source": [
    "train(opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'train' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain\u001b[49m(opt)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'train' is not defined"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer(opt):\n",
    "    \"\"\"Play one episode greedily with a saved checkpoint, rendering each frame inline.\n",
    "\n",
    "    Args:\n",
    "        opt: Options dict. Keys read here: 'action_type' (\"right\", \"simple\";\n",
    "            any other value selects COMPLEX_MOVEMENT), 'world', 'stage',\n",
    "            'saved_path' and 'episode' (the last four build the checkpoint path).\n",
    "\n",
    "    The loop stops when info[\"flag_get\"] is truthy (stage completed) or when\n",
    "    the environment reports done (reported as 'Game Failed'). The environment\n",
    "    is closed on exit, including when an exception is raised.\n",
    "    \"\"\"\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    if use_cuda:\n",
    "        torch.cuda.manual_seed(123)\n",
    "    else:\n",
    "        torch.manual_seed(123)\n",
    "\n",
    "    if opt['action_type'] == \"right\":\n",
    "        actions = RIGHT_ONLY\n",
    "    elif opt['action_type'] == \"simple\":\n",
    "        actions = SIMPLE_MOVEMENT\n",
    "    else:\n",
    "        actions = COMPLEX_MOVEMENT\n",
    "\n",
    "    env = create_train_env(opt['world'], opt['stage'], actions)\n",
    "    try:\n",
    "        model = Net(env.observation_space.shape[0], len(actions))\n",
    "        checkpoint_path = \"{}/ppo_super_mario_bros_{}_{}_{}\".format(\n",
    "            opt['saved_path'], opt['world'], opt['stage'], opt['episode'])\n",
    "        # Without CUDA, remap the checkpoint tensors onto the CPU; with CUDA,\n",
    "        # keep torch.load's default placement and move the model afterwards.\n",
    "        map_location = None if use_cuda else torch.device('cpu')\n",
    "        model.load_state_dict(torch.load(checkpoint_path, map_location=map_location))\n",
    "        if use_cuda:\n",
    "            model.cuda()\n",
    "        model.eval()\n",
    "        state = torch.from_numpy(env.reset())\n",
    "\n",
    "        plt.figure(figsize=(10,10))\n",
    "        img = plt.imshow(env.render(mode='rgb_array'))\n",
    "\n",
    "        while True:\n",
    "            if use_cuda:\n",
    "                state = state.cuda()\n",
    "            # Inference only: no gradients are needed, so skip autograd bookkeeping.\n",
    "            with torch.no_grad():\n",
    "                logits, _ = model(state)\n",
    "                policy = F.softmax(logits, dim=1)\n",
    "            action = torch.argmax(policy).item()\n",
    "            state, reward, done, info = env.step(action)\n",
    "            state = torch.from_numpy(state)\n",
    "\n",
    "            img.set_data(env.render(mode='rgb_array')) # just update the data\n",
    "            display.display(plt.gcf())\n",
    "            display.clear_output(wait=True)\n",
    "\n",
    "            if info[\"flag_get\"]:\n",
    "                print(\"World {} stage {} completed\".format(opt['world'], opt['stage']))\n",
    "                break\n",
    "\n",
    "            # The success case has already exited above, so `done` alone means\n",
    "            # failure. An `is False` identity test would miss falsy values that\n",
    "            # are not the Python False singleton (e.g. a NumPy boolean) and\n",
    "            # leave this loop running forever after a lost episode.\n",
    "            if done:\n",
    "                print('Game Failed')\n",
    "                break\n",
    "    finally:\n",
    "        # Release the environment even if loading or stepping raised.\n",
    "        env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/gym/envs/registration.py:555: UserWarning: \u001b[33mWARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.\u001b[0m\n",
      "  logger.warn(\n"
     ]
    },
    {
     "ename": "OverflowError",
     "evalue": "Python integer 1024 out of bounds for uint8",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOverflowError\u001b[0m                             Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43minfer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[7], line 12\u001b[0m, in \u001b[0;36minfer\u001b[0;34m(opt)\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     11\u001b[0m     actions \u001b[38;5;241m=\u001b[39m COMPLEX_MOVEMENT\n\u001b[0;32m---> 12\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_train_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mworld\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mstage\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mactions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     13\u001b[0m model \u001b[38;5;241m=\u001b[39m Net(env\u001b[38;5;241m.\u001b[39mobservation_space\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mlen\u001b[39m(actions))\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available():\n",
      "Cell \u001b[0;32mIn[4], line 4\u001b[0m, in \u001b[0;36mcreate_train_env\u001b[0;34m(world, stage, actions, output_path)\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_train_env\u001b[39m(world, stage, actions, output_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m      3\u001b[0m     \u001b[38;5;66;03m# 创建基础环境\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m     env \u001b[38;5;241m=\u001b[39m \u001b[43mgym_super_mario_bros\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSuperMarioBros-\u001b[39;49m\u001b[38;5;132;43;01m{}\u001b[39;49;00m\u001b[38;5;124;43m-\u001b[39;49m\u001b[38;5;132;43;01m{}\u001b[39;49;00m\u001b[38;5;124;43m-v0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworld\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      6\u001b[0m     env \u001b[38;5;241m=\u001b[39m JoypadSpace(env, actions)\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;66;03m# 对环境自定义\u001b[39;00m\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/gym/envs/registration.py:640\u001b[0m, in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m    637\u001b[0m     render_mode \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    639\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 640\u001b[0m     env \u001b[38;5;241m=\u001b[39m \u001b[43menv_creator\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    641\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    642\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m    643\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e)\u001b[38;5;241m.\u001b[39mfind(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgot an unexpected keyword argument \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrender_mode\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    644\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m apply_human_rendering\n\u001b[1;32m    645\u001b[0m     ):\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/gym_super_mario_bros/smb_env.py:52\u001b[0m, in \u001b[0;36mSuperMarioBrosEnv.__init__\u001b[0;34m(self, rom_mode, lost_levels, target)\u001b[0m\n\u001b[1;32m     50\u001b[0m rom \u001b[38;5;241m=\u001b[39m rom_path(lost_levels, rom_mode)\n\u001b[1;32m     51\u001b[0m \u001b[38;5;66;03m# initialize the super object with the ROM path\u001b[39;00m\n\u001b[0;32m---> 52\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mSuperMarioBrosEnv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mrom\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     53\u001b[0m \u001b[38;5;66;03m# set the target world, stage, and area variables\u001b[39;00m\n\u001b[1;32m     54\u001b[0m target \u001b[38;5;241m=\u001b[39m decode_target(target, lost_levels)\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/nes_env.py:126\u001b[0m, in \u001b[0;36mNESEnv.__init__\u001b[0;34m(self, rom_path)\u001b[0m\n\u001b[1;32m    124\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mROM has trainer. trainer is not supported.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    125\u001b[0m \u001b[38;5;66;03m# try to read the PRG ROM and raise a value error if it fails\u001b[39;00m\n\u001b[0;32m--> 126\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mrom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom\u001b[49m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;66;03m# try to read the CHR ROM and raise a value error if it fails\u001b[39;00m\n\u001b[1;32m    128\u001b[0m _ \u001b[38;5;241m=\u001b[39m rom\u001b[38;5;241m.\u001b[39mchr_rom\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/_rom.py:204\u001b[0m, in \u001b[0;36mROM.prg_rom\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    202\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the PRG ROM of the ROM file.\"\"\"\u001b[39;00m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 204\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw_data[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprg_rom_start:\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom_stop\u001b[49m]\n\u001b[1;32m    205\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m:\n\u001b[1;32m    206\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfailed to read PRG-ROM on ROM.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[0;32m~/code/reinforcement-learning-to-play-games/.conda/lib/python3.11/site-packages/nes_py/_rom.py:198\u001b[0m, in \u001b[0;36mROM.prg_rom_stop\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    195\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m    196\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprg_rom_stop\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    197\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"The exclusive stopping index of the PRG ROM.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprg_rom_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprg_rom_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\n",
      "\u001b[0;31mOverflowError\u001b[0m: Python integer 1024 out of bounds for uint8"
     ]
    }
   ],
   "source": [
    "# Run the rollout defined in infer() above, using the settings stored in opt.\n",
    "infer(opt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
